import torch
from torch import nn
class DeepReservoirNet(nn.Module):
    """Leaky-integrator reservoir with a linear readout and self-attention.

    The input and recurrent (reservoir) weights are randomly initialised and
    frozen; only the readout ``W_out`` and the self-attention layer carry
    trainable parameters.

    Args:
        input_size: Number of features per time step of the input.
        reservoir_size: Number of reservoir units.
        output_size: Number of features produced by the readout (also the
            embedding size of the self-attention layer).
        spectral_radius: Target spectral radius of the recurrent matrix.
        leaky: Leak rate used to blend the previous state with the update.
        sparsity: Fraction of recurrent weights that are set to zero.
        num_heads: Number of attention heads; must divide ``output_size``.
    """

    def __init__(self, input_size=768, reservoir_size=1000, output_size=768,
                 spectral_radius=0.9, leaky=0.3, sparsity=0.5, num_heads=32):
        super().__init__()

        self.input_size = input_size
        self.reservoir_size = reservoir_size
        self.output_size = output_size
        self.spectral_radius = spectral_radius
        self.leaky = leaky

        # Frozen random projections: input -> reservoir and reservoir -> reservoir.
        self.W_in = nn.Linear(input_size, reservoir_size, bias=False)
        self.W_in.weight.requires_grad_(False)
        self.W_res = nn.Linear(reservoir_size, reservoir_size, bias=False)
        self.W_res.weight.requires_grad_(False)
        # Trainable readout.
        self.W_out = nn.Linear(reservoir_size, output_size)

        # Non-persistent buffers: they follow ``.to(device)`` with the module
        # but are not added to ``state_dict``.
        self.register_buffer(
            "res_state", torch.zeros(1, reservoir_size), persistent=False
        )
        self.register_buffer(
            "W_res_norm", self.compute_spectral_radius(sparsity), persistent=False
        )

        # Attention runs on the readout, so its embedding size is output_size.
        self.self_attention = nn.MultiheadAttention(
            output_size, num_heads, dropout=0.2
        )

    def compute_spectral_radius(self, sparsity=0.5):
        """Re-initialise ``W_res`` as a sparse random matrix with the target radius.

        Args:
            sparsity: Fraction of the recurrent weights to set to zero.

        Returns:
            0-dim tensor holding the spectral radius of the random matrix
            *before* it was rescaled to ``self.spectral_radius``.

        Raises:
            ValueError: If the sampled matrix has spectral radius zero (for
                example when ``sparsity`` is 1.0), which cannot be rescaled.
        """
        size = self.reservoir_size
        with torch.no_grad():
            weight = torch.randn(size, size)
            # Set a random fraction of the entries to zero.
            num_zeros = int(sparsity * size ** 2)
            zero_idx = torch.randperm(size ** 2)[:num_zeros]
            weight.view(-1)[zero_idx] = 0

            radius = torch.linalg.eigvals(weight).abs().max()
            if radius == 0:
                raise ValueError(
                    "Reservoir matrix has zero spectral radius; "
                    "use a smaller sparsity."
                )
            # Rescale so the largest eigenvalue magnitude equals the target.
            weight *= self.spectral_radius / radius
            self.W_res.weight.copy_(weight)
        return radius

    def forward(self, input_data, res_state=None):
        """Run the reservoir over a sequence and attend over the readouts.

        Args:
            input_data: Tensor of shape (batch_size, seq_length, input_size).
            res_state: Initial reservoir state, broadcastable to
                (batch_size, reservoir_size). Defaults to the module's
                zero state ``self.res_state``.

        Returns:
            Dict with ``'Output'`` of shape
            (batch_size, seq_length, output_size) and ``'State'``, the
            reservoir state after the last time step.
        """
        if res_state is None:
            res_state = self.res_state

        batch_size, seq_length = input_data.shape[0], input_data.shape[1]
        if seq_length == 0:
            empty = input_data.new_zeros((batch_size, 0, self.output_size))
            return {'Output': empty, "State": res_state}

        # The input projection does not depend on the recurrent state, so it
        # is computed for all time steps at once.
        input_proj = self.W_in(input_data)

        states = []
        for t in range(seq_length):
            res_proj = self.W_res(res_state)
            # Leaky integration of the previous state and the new activation.
            res_state = (
                (1 - self.leaky) * res_state
                + self.leaky * torch.tanh(input_proj[:, t, :] + res_proj)
            )
            # Normalize reservoir state by the pre-scaling spectral radius.
            # NOTE(review): kept from the original design; confirm this
            # per-step division is intended.
            res_state = res_state / self.W_res_norm
            states.append(res_state)

        # (seq_length, batch_size, output_size): the layout expected by
        # nn.MultiheadAttention with its default batch_first=False, so that
        # attention runs over time steps and never mixes batch samples.
        readout = self.W_out(torch.stack(states, dim=0))
        attended, _ = self.self_attention(
            readout, readout, readout, need_weights=False
        )

        # Back to (batch_size, seq_length, output_size).
        output = attended.transpose(0, 1).contiguous()
        return {'Output': output, "State": res_state}

